'''
Author: Zhuangyu
Date: 2022-09-28 11:22:16
LastEditTime: 2022-09-30 01:42:26
'''
from __future__ import print_function, division # Grab some 
import numpy as np
from casadi import *
# This module contains user-defined functions implementing the
# van Genuchten soil-water constitutive relations.

def thetaFun(psi, pars):
    """Volumetric moisture content theta as a function of pressure head psi.

    van Genuchten retention curve:
        Se    = 1                               where psi >= 0
        Se    = (1 + |alpha*psi|**n) ** (-m)     where psi <  0
        theta = thetaR + (thetaS - thetaR) * Se

    Args:
        psi: pressure head. It is passed straight to casadi ``if_else`` and
            ``fabs``, so it may be a number or a casadi expression.
        pars: mapping read for the keys 'alpha', 'n', 'm', 'thetaR', 'thetaS'.

    Returns:
        theta, of whatever type the casadi/arithmetic operations produce.
    """
    alpha, n, m = pars['alpha'], pars['n'], pars['m']
    theta_r, theta_s = pars['thetaR'], pars['thetaS']

    # The unsaturated expression is always built; if_else then substitutes
    # full saturation (Se = 1) wherever psi >= 0.
    unsaturated = (1 + fabs(psi * alpha) ** n) ** (-m)
    se = if_else(psi >= 0., 1., unsaturated)

    return theta_r + (theta_s - theta_r) * se

def CFun(psi, pars):
    """Capillary capacity C as a function of pressure head psi.

    Computes
        C = Se * Ss + (thetaS - thetaR) * dSe/dh
    with
        dSe/dh = alpha * m / (1 - m) * Se**(1/m) * (1 - Se**(1/m))**m

    NOTE(review): this closed form for dSe/dh is the magnitude of the
    derivative of the van Genuchten Se(h) only under the usual constraint
    m = 1 - 1/n. That constraint is assumed here, not enforced.
    Where psi >= 0, Se = 1, so the dSe/dh term vanishes and C = Ss.

    Args:
        psi: pressure head (number or casadi expression; used with casadi
            ``if_else`` / ``fabs``).
        pars: mapping read for 'alpha', 'n', 'm', 'Ss', 'thetaR', 'thetaS'.

    Returns:
        C, of whatever type the casadi/arithmetic operations produce.
    """
    alpha, n, m = pars['alpha'], pars['n'], pars['m']

    # Effective saturation, clamped to 1 for non-negative head.
    unsaturated = (1 + fabs(psi * alpha) ** n) ** (-m)
    se = if_else(psi >= 0., 1., unsaturated)

    se_root = se ** (1 / m)
    dse_dh = alpha * m / (1 - m) * se_root * (1 - se_root) ** m

    return se * pars['Ss'] + (pars['thetaS'] - pars['thetaR']) * dse_dh

def KFun(psi, pars):
    """Unsaturated hydraulic conductivity K as a function of pressure head psi.

    Uses the Mualem-van Genuchten form
        K = Ks * Se**neta * (1 - (1 - Se**(1/m))**m)**2
    where Se = 1 for psi >= 0, so K = Ks at and above saturation.

    Args:
        psi: pressure head (number or casadi expression; used with casadi
            ``if_else`` / ``fabs``).
        pars: mapping read for 'alpha', 'n', 'm', 'Ks', 'neta'.

    Returns:
        K, of whatever type the casadi/arithmetic operations produce.
    """
    alpha, n, m = pars['alpha'], pars['n'], pars['m']

    unsaturated = (1 + fabs(psi * alpha) ** n) ** (-m)
    se = if_else(psi >= 0., 1., unsaturated)

    bracket = (1 - (1 - se ** (1 / m)) ** m) ** 2
    return pars['Ks'] * se ** pars['neta'] * bracket

def hFun(theta, pars):
    """Pressure head h as a function of volumetric moisture content theta.

    Inverts the van Genuchten retention curve:
        Se = (theta - thetaR) / (thetaS - thetaR)
        h  = -(Se**(-1/m) - 1)**(1/n) / alpha,   with m = 1 - 1/n

    m is recomputed from pars['n'] here; pars['m'] is not read.
    NOTE(review): confirm callers keep pars['m'] == 1 - 1/n.

    Small 1e-20 shifts are added at several points. At theta == thetaS they
    make h a tiny negative number (about -(1e-20)**(1/n)/alpha) rather than 0.
    At theta == thetaR they keep Se strictly positive. For parameter values
    of ordinary magnitude, the shifts on n, m, alpha and (thetaS - thetaR)
    are absorbed by floating-point rounding.

    NOTE(review): theta outside [thetaR, thetaS] raises a negative base to a
    fractional power. That gives NaN for numpy inputs and a complex number
    for plain Python floats. Callers should keep theta within that range.

    Args:
        theta: volumetric moisture content. Only arithmetic is applied, so
            floats, numpy arrays or casadi expressions all work.
        pars: mapping read for 'thetaR', 'thetaS', 'n', 'alpha'.

    Returns:
        h, of the same kind of object the arithmetic produces.
    """
    tiny = 1.e-20
    n, alpha = pars['n'], pars['alpha']
    theta_r, theta_s = pars['thetaR'], pars['thetaS']

    # Effective saturation Se, nudged away from 0.
    se = (theta - theta_r) / (theta_s - theta_r + tiny) + tiny
    # Exponent approximately -1/m with m = 1 - 1/n.
    neg_inv_m = 1. / (-(1 - 1 / (n + tiny)) + tiny)
    # (Se**(-1/m) - 1), nudged away from 0 at saturation.
    base = (se ** neg_inv_m - 1) + tiny

    return base ** (1. / (n + tiny)) / (-alpha + tiny)